{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 99,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import os.path as osp\n",
    "import numpy as np\n",
    "import torch\n",
    "from torch import nn\n",
    "from torch.utils.data import DataLoader\n",
    "from src.utils.load_cfg import ConfigLoader\n",
    "from src.factories import ModelFactory, LossFactory, InferFactory\n",
    "from src.loaders.base_loader_factory import BaseDataLoaderFactory\n",
    "from trainer import train\n",
    "from tester import test\n",
    "import src.utils.logging as logging\n",
    "import itertools\n",
    "from matplotlib import pyplot as plt\n",
    "import ipdb\n",
    "import src.config as cfg\n",
    "from src.inferences.base_infer import BaseInfer\n",
    "import glob\n",
    "import xml.etree.ElementTree as ET\n",
    "from shutil import copyfile, copytree\n",
    "import pandas as pd\n",
    "import pickle as pkl\n",
    "from src.utils.reid_metrics import reid_evaluate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "./dataset_splits/aic20_vehicle_reid/reid_test.csv\n",
      "./dataset_splits/aic20_vehicle_reid/reid_query.csv\n"
     ]
    }
   ],
   "source": [
    "\"\"\"Prepeare TEST data loaders\"\"\"\n",
    "dataset_cfg = \"configs/dataset_cfgs/aic20_vehicle_reid.yaml\"\n",
    "train_cfg   = \"configs/train_cfgs/aic20_t2_trip_onlloss.yaml\"\n",
    "train_params = ConfigLoader.load_train_cfg(train_cfg)\n",
    "train_params = ConfigLoader.load_train_cfg(train_cfg)\n",
    "\n",
    "common_loader_params = {'batch_size': train_params['batch_size'],'num_workers': 4,}\n",
    "dataset_name, dataset_params = ConfigLoader.load_dataset_cfg(dataset_cfg)\n",
    "loader_fact = BaseDataLoaderFactory(dataset_name, dataset_params, train_params, common_loader_params)\n",
    "test_loaders = loader_fact.test_loaders()\n",
    "gal_ld = test_loaders['gallery']\n",
    "que_ld = test_loaders['query']\n",
    "\"\"\"Prepare GALLERY & QUERY file name\"\"\"\n",
    "que_fname = que_ld.dataset.get_img_names()\n",
    "gal_fname = gal_ld.dataset.get_img_names()\n",
    "que_fname = np.array([int(i) for i in que_fname]).astype(np.int32)\n",
    "gal_fname = np.array([int(i) for i in gal_fname]).astype(np.int32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\" Read test track id file\n",
    "    - input: txt file. test track id, each row contains all imgid in that tracklet\n",
    "    - output: \n",
    "            - (dict) gal_tracks.test_tracks[i] = np.array(img_id)\n",
    "            - (dict) galimg2track. queimg2track[img_id] = track_id\"\"\"\n",
    "with open(\"../aic20_data/origin/test_track_id.txt\") as fi:\n",
    "    lines = fi.readlines()\n",
    "gal_tracks = {}\n",
    "galimg2track = {}\n",
    "for i,track in enumerate(lines):\n",
    "    gal_tracks[i] = np.array(track.strip().split(' ')).astype(np.int) - 1\n",
    "    #note that this is now 0 index!!!\n",
    "    for img_id in gal_tracks[i]:\n",
    "        galimg2track[img_id] = i"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(1052, 18290)\n"
     ]
    }
   ],
   "source": [
    "\"\"\"Get raw  distance\"\"\"\n",
    "RAW_OUT_DIR = \"outputs/aic20_t2_trip_onlloss/\"\n",
    "raw_dist = np.load(osp.join(RAW_OUT_DIR,\"dist.npy\"))\n",
    "print(raw_dist.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Using Re-ID"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 104,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Reranking is applied!\n",
      "using GPU to compute original distance\n",
      "starting re_ranking\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Evaluating: 100% [############################] loss: ---------- Time:  0:00:00\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Reranked - Validation mAP: 1.0000\n",
      "Reranked - Validation cmc (hard): 1.0000\n"
     ]
    }
   ],
   "source": [
    "raw_que_emb = np.load(osp.join(RAW_OUT_DIR,\"que_emb.npy\"))\n",
    "raw_gal_emb = np.load(osp.join(RAW_OUT_DIR,\"gal_emb.npy\"))\n",
    "tmp = []\n",
    "for track_id in gal_tracks:\n",
    "    tmp.append(np.mean(raw_gal_emb[gal_tracks[track_id]],axis = 0)[:,np.newaxis])\n",
    "avg_gal_emb = np.concatenate(tmp, axis = 1).transpose()\n",
    "raw_que_emb = torch.from_numpy(raw_que_emb)\n",
    "avg_gal_emb = torch.from_numpy(avg_gal_emb)\n",
    "idcs,mAP, cmc, dist_reranked = reid_evaluate(raw_que_emb, avg_gal_emb, \\\n",
    "                np.ones(1052), np.ones(18290), is_reranking = True)\n",
    "print('Reranked - Validation mAP: %.4f' % mAP)\n",
    "print('Reranked - Validation cmc (hard): %.4f' % cmc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 107,
   "metadata": {},
   "outputs": [],
   "source": [
    "que2track_idcs = np.argsort(dist_reranked, axis = 1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### QUERY TO TRACKLET"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"calculate \"mean\" dist between query image and tracklets\"\"\"\n",
    "tmp = []\n",
    "for track_id in gal_tracks:\n",
    "    tmp.append(np.mean(raw_dist[:,gal_tracks[track_id]],axis = 1)[:,np.newaxis])\n",
    "que2track_dist = np.concatenate(tmp, axis = 1)\n",
    "que2track_idcs = np.argsort(que2track_dist, axis = 1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### SAVE SUBMISSION FILE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "def save_predictions(tracklet_dist, out_file ):\n",
    "    que2track_idcs = np.argsort(tracklet_dist)\n",
    "    rows = []\n",
    "    for idx in range(que2track_idcs.shape[0]):\n",
    "        tmp = []\n",
    "        for track in que2track_idcs[idx,:]:\n",
    "            tmp.append(gal_tracks[track])\n",
    "        rows.append(np.concatenate(tmp,axis=0)[:,np.newaxis])\n",
    "    final_idcs = np.concatenate(rows,axis=1).transpose()[:,:100]\n",
    "    print(\"Saving submission file\")\n",
    "    out_file  = osp.join(\"\", out_file)\n",
    "    np.savetxt(out_file, gal_fname[final_idcs], \n",
    "            delimiter = \" \", fmt = \"%d\", newline='\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 255,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saving submission file\n"
     ]
    }
   ],
   "source": [
    "save_predictions(que2track_dist,\"track2_onlloss_tracklet_2.txt\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### VEHICLE TYPE ATTRIBUTE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 108,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(1052,) (18290,)\n"
     ]
    }
   ],
   "source": [
    "#Load pickle files\n",
    "gal_veh_type_dict = pkl.load(open(\"../aic20_attributes/test_types.pkl\", \"rb\"))\n",
    "que_veh_type_dict = pkl.load(open(\"../aic20_attributes/query_types.pkl\", \"rb\"))\n",
    "que_type = np.array([que_veh_type_dict[img_id] for img_id in que_fname])\n",
    "gal_type = np.array([gal_veh_type_dict[img_id] for img_id in gal_fname])\n",
    "print(que_type.shape, gal_type.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 109,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_majority(a):\n",
    "    (values,counts) = np.unique(a,return_counts=True)\n",
    "    ind=np.argmax(counts)\n",
    "    return values[ind] \n",
    "gal_track_type = []\n",
    "for track_id in gal_tracks:\n",
    "    gal_track_type.append(get_majority(gal_type[gal_tracks[track_id]]))\n",
    "gal_track_type = np.array(gal_track_type)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 110,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "255"
      ]
     },
     "execution_count": 110,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#Manual check with first tracklet\n",
    "first_track_type = gal_track_type[que2track_idcs[:,0]]\n",
    "diff = np.where(first_track_type != que_type)[0]\n",
    "len(diff)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 111,
   "metadata": {},
   "outputs": [],
   "source": [
    "#penalties for vehicle types\n",
    "vehi_type_penal = (que_type[:, np.newaxis] != gal_track_type) * 1.0\n",
    "que2track_final_dist = que2track_init_dist + vehi_type_penal"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 112,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1052, 798)"
      ]
     },
     "execution_count": 112,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.argsort(que2track_final_dist).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 259,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saving submission file\n"
     ]
    }
   ],
   "source": [
    "save_predictions(que2track_final_dist, \"track2_tracklet_vehitype_pen1.0.txt\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Read all attributes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 113,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Load pickle files\n",
    "que_dict = pkl.load(open(\"../aic20_attributes/aic20_que_imgs_attribs.pkl\", \"rb\"))\n",
    "gal_dict = pkl.load(open(\"../aic20_attributes/aic20_gal_imgs_attribs.pkl\", \"rb\"))\n",
    "\n",
    "def dict2npy(que_key, gal_key):\n",
    "    que_lb = np.array([que_dict[img_id][que_key]['lbl'] for img_id in que_fname])\n",
    "    que_sc = np.array([que_dict[img_id][que_key]['scr'] for img_id in que_fname])\n",
    "    gal_lb = np.array([gal_dict[img_id][gal_key]['lbl'] for img_id in gal_fname])\n",
    "    gal_sc = np.array([gal_dict[img_id][gal_key]['scr'] for img_id in gal_fname])\n",
    "    return {\"que_lb\": que_lb, \"que_sc\": que_sc, \"gal_lb\":gal_lb, \"gal_sc\":gal_sc}\n",
    "\n",
    "attr = {}\n",
    "attr[\"type\"] = dict2npy(\"que_type8\", \"gal_type8\")\n",
    "attr[\"type6\"] = dict2npy(\"que_type6\", \"gal_type6\")\n",
    "attr[\"view\"] = dict2npy(\"que_view\", \"gal_view\")\n",
    "attr[\"wheel\"] = dict2npy(\"que_wheel\", \"gal_wheel\")\n",
    "attr[\"top_view\"] = dict2npy(\"que_topview\", \"gal_topview\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 114,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_majority(a):\n",
    "    (values,counts) = np.unique(a,return_counts=True)\n",
    "    ind=np.argmax(counts)\n",
    "    return values[ind] \n",
    "\n",
    "\"\"\"Generate mask with condition que_scr & gal_scr > thresh\"\"\"\n",
    "def apply_threshold(attr_dict, que_thresh, gal_thresh):\n",
    "    que_scr = attr_dict[\"que_sc\"]\n",
    "    gal_scr = attr_dict[\"gal_sc\"]\n",
    "    que_msk = que_scr > que_thresh\n",
    "    gal_msk = gal_scr > gal_thresh\n",
    "    return que_msk, gal_msk"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 115,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total reliable tracklets: 647 / 798 with 8202 / 18920 images\n",
      "Total reliable queries  : 420 / 1052\n",
      "Total reliable tracklets: 760 / 798 with 4319 / 18920 images\n",
      "Total reliable queries  : 349 / 1052\n",
      "Total reliable tracklets: 134 / 798 with 608 / 18920 images\n",
      "Total reliable queries  : 226 / 1052\n",
      "Total reliable tracklets: 774 / 798 with 16172 / 18920 images\n",
      "Total reliable queries  : 933 / 1052\n"
     ]
    }
   ],
   "source": [
    "def que_gal_match_attrib(attr_dict, que_thresh, gal_thresh):\n",
    "    gal_lb = attr_dict[\"gal_lb\"].copy()\n",
    "    que_lb = attr_dict[\"que_lb\"].copy()\n",
    "    que_msk, gal_msk = apply_threshold(attr_dict, que_thresh, gal_thresh)\n",
    "    #set all label with scr < thresh to -1\n",
    "    gal_lb_masked = gal_lb.copy()\n",
    "    gal_lb_masked[np.logical_not(gal_msk)] = -1\n",
    "    \n",
    "    gal_track_label = []\n",
    "    for track_id in gal_tracks: #for all gallery tracks\n",
    "        #get lbls of all imgs in track\n",
    "        lb_of_imgs_in_track = gal_lb_masked[gal_tracks[track_id]]\n",
    "        #remove all labels = -1\n",
    "        lb_of_imgs_in_track = lb_of_imgs_in_track[lb_of_imgs_in_track != -1]\n",
    "\n",
    "        #get majority,\n",
    "        if len(lb_of_imgs_in_track) > 0:\n",
    "            track_lb = get_majority(lb_of_imgs_in_track)\n",
    "        else:\n",
    "            track_lb = -1\n",
    "        gal_track_label.append(track_lb)\n",
    "    #Convert list to npy\n",
    "    gal_track_label = np.array(gal_track_label)\n",
    "    #Mask for gallery tracklet labels (!= -1)\n",
    "    gal_track_mask = gal_track_label != -1\n",
    "    print(\"Total reliable tracklets: %d / 798 with %d / 18920 images\"\n",
    "          % (gal_track_mask.sum(), gal_msk.sum()))\n",
    "    print(\"Total reliable queries  : %d / 1052\" % que_msk.sum())\n",
    "    #Tracklet prediction != query prediction (true: diff, false: NO diff)\n",
    "    diff_mask = que_lb[:, np.newaxis] != gal_track_label\n",
    "#     print(diff_mask.sum()) #Set all unreliable queries to false: NO diff \n",
    "    diff_mask[np.logical_not(que_msk),:] = False\n",
    "#     print(diff_mask.sum()) #Set all unreliable tracklet to false: No diff\n",
    "    diff_mask[:,np.logical_not(gal_track_mask)] = False\n",
    "#     print(diff_mask.sum())\n",
    "    return diff_mask, que_msk, gal_track_mask\n",
    "\n",
    "type_diff_mask, type_que_msk, type_gal_msk    = \\\n",
    "    que_gal_match_attrib(attr[\"type\"],que_thresh=0.9,gal_thresh=0.9)\n",
    "\n",
    "topview_diff_mask,  topview_que_msk, topview_gal_msk = \\\n",
    "que_gal_match_attrib(attr[\"top_view\"],que_thresh=0.9,gal_thresh=0.)\n",
    "\n",
    "wheel_diff_mask,  wheel_que_msk, wheel_gal_msk   = \\\n",
    "que_gal_match_attrib(attr[\"wheel\"],que_thresh=0.6,gal_thresh=0.8)\n",
    "\n",
    "type6_diff_mask, type6_que_msk, type6_gal_msk    = \\\n",
    "    que_gal_match_attrib(attr[\"type6\"],que_thresh=0.9,gal_thresh=0.9)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 116,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.0\n",
      "242 {'intermodal': [283, 283]}\n",
      "328 {'wrestoration': [57, 57, 356]}\n",
      "350 {'worrywhe': [217], '400': [185, 683, 217, 508, 217, 771, 758, 185, 185, 683, 217, 508, 217, 771, 758, 185]}\n",
      "352 {'wrestoration': [57, 57, 356], '563543': [57, 57], '4887': [57]}\n",
      "399 {'homecare': [260, 96]}\n",
      "444 {'uctric': [484]}\n",
      "647 {'268': [565, 536, 565, 553, 175, 536, 565]}\n",
      "827 {'apless': [161]}\n",
      "901 {'4954444': [332, 332], '563': [185, 683, 508, 771, 57, 57, 758, 185]}\n",
      "916 {'2659': [166, 553]}\n"
     ]
    }
   ],
   "source": [
    "#initial distance by triplet-reid\n",
    "que2track_init_dist = np.zeros((1052,798)) \n",
    "trip_w = np.arange(798)\n",
    "for i in range(que2track_init_dist.shape[0]):\n",
    "    que2track_init_dist[i,que2track_idcs[i,:]] = trip_w * 1.0\n",
    "\n",
    "que2track_final_dist = que2track_init_dist\n",
    "print(que2track_final_dist[0][686]) # -> check if it is 0.0: oh yeah!\n",
    "\n",
    "#Use vehicle type attribute\n",
    "que2track_final_dist += type_diff_mask * 10.\n",
    "que2track_final_dist += topview_diff_mask * 10.\n",
    "que2track_final_dist += wheel_diff_mask * 10.\n",
    "que2track_final_dist += type6_diff_mask * 10.\n",
    "\n",
    "#\n",
    "text_dict = pkl.load(open(\"../aic20_attributes/aic20_scencetext_attribs.pkl\", \"rb\"))\n",
    "scence_text_scr = np.zeros((1052,798))\n",
    "for img_id in text_dict:\n",
    "    print(img_id, text_dict[img_id])\n",
    "    for selected_idcs in text_dict[img_id].values():\n",
    "        scence_text_scr[img_id - 1, np.array(selected_idcs)] = 1\n",
    "\n",
    "que2track_final_dist += scence_text_scr * -50.\n",
    "\n",
    "wheel_cuong = np.load(\"../aic20_attributes/wheel.npy\")\n",
    "where_are_NaNs = np.isnan(wheel_cuong)\n",
    "wheel_cuong[where_are_NaNs] = 10.0\n",
    "wheel_cuong_msk = wheel_cuong < 0.5\n",
    "# (wheel_cuong_msk == False).sum()\n",
    "que2track_final_dist += wheel_cuong_msk * 10.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saving submission file\n"
     ]
    }
   ],
   "source": [
    "save_predictions(que2track_final_dist, \"t2_track_0.9_topv10.0_type10.0_q0.6g0.8_wheel10.0_type6_10.0_scencetext_wheel_Cuong.txt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 117,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saving submission file\n"
     ]
    }
   ],
   "source": [
    "save_predictions(que2track_final_dist, \"t2_reranked_track_0.9_topv10.0_type10.0_q0.6g0.8_wheel10.0_type6_10.0_scencetext_wheel_Cuong.txt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
